查看原文
其他

论文回顾 | [ICML 2017] 一种深度网络快速适应的模型无关元学习方法(元学习经典论文)

本文简要介绍ICML2017论文“Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks”的主要工作。该论文是元学习领域的经典论文之一,提出了一种可用于任何梯度下降训练的模型的元学习方法,该方法在图像分类、回归和强化学习任务中都有很好的效果。

 

一、研究背景 

目前的人工智能在众多任务中都取得了令人惊讶的表现,但是大部分人工智能系统需要大量的训练数据和较长的训练时间来学习一个任务,并且单个AI模型一般只能处理单一的任务,当其面对一个新的任务时往往束手无策。相比而言,人类智能在面对一个新任务时,往往可以根据以往的知识积累和学习技巧快速的通过少量样本进行学习,那么人工智能是否也可以和人类一样呢?元学习(Meta-learning)就是学习如何学习,它研究的不是如何提升模型的解决某项具体的任务的能力,而是研究如何提升模型解决一系列任务的能力;元学习是让模型学习到各种任务的共性或者学习技巧之后可以通过少量的样本快速的学习一个新任务。这种学习方式主要挑战在于少量的样本情况下如何学习到任务并且不会过拟合,同时如何做到快速的适应新任务。该论文的方法是让模型在众多任务的训练过程中学习到一个模型初始化参数,这个初始化参数可能在每个训练任务中的表现并不是最优,但是基于这个模型参数,模型可以在少量样本中只经过几步梯度下降就可以学习到一个新的任务。

 

二、MAML的算法 

该论文提出了一种通过梯度下降来训练的元学习方法,它可以运用于许多任务中,如图像分类,回归和强化学习。MAML的关键思想是在大量的训练任务中学习到一个适应各种任务的模型初始化参数,使得模型可以通过少量的梯度更新次数就可以快速的学习新的任务。其梯度更新简化图如下图1,中间实线和虚线相交的黑点是希望学习的模型初始化参数,θ为模型参数,不同任务上求出的梯度。为需要学习的新任务的模型参数。



 相比较之前的元学习方法,MAML有以下有优点:

  • 不会增加模型的学习参数;

  • 不会对模型类型有限制,理论上是所有基于梯度下降的模型都可以运用MAML;

  • 可运用于大部分损失函数,如可微的监督损失函数和不可微的强化学习的损失函数。


 

如上图为MAML算法的训练具体流程,首先需要准备一个任务分布P(T)和两个参数更新系数α和β;然后对模型的参数进行随机初始化;然后从任务分布P(T)中抽取出一批任务,对每一个任务使用其对应的K个训练样本进行第一次梯度计算,用得到的梯度计算更新的模型参数(此处只是计算模型更新参数,并不会真实更新模型参数);然后用这一批任务得到模型参数重新计算损失函数并加和,计算第二次梯度用于更新模型参数θ;如此循环迭代3~8行的步骤就可以训练更新得到一个很好的模型初始化参数,这个参数可以表示大量任务的共性。基于训练得到的网络参数,模型就可以在新任务的少量样本上进行几次梯度下降就可以取得很好的泛化性能。

 

三、主要实验结果及可视化效果 

作者的实验基于小样本学习(Few-shot Learning)问题,小样本学习中N-way K-shot是指样本的类别为N个,每个类别有K个样本,在每个任务中包含有N*K个训练样本(Support Set)和未知类别的测试样本(Query Set)。

(1)回归任务

回归任务中,作者采用了两层40大小的隐含层的神经网络来拟合一条正弦函数;不同的幅度和相位的正弦函数代表不同的任务,每条正弦曲线分别采样K(K={5,10,20})个样本点作为训练样本;损失函数为MSE(Mean-squared Error)。除了使用MAML进行模型训练,作者还训练了两个基础模型进行对比:一是在所有任务的训练样本点进行预训练,最后在测试集上进行微调,二是将所有正弦函数真实的相位和幅度作为输入的模型。



如上图2中左侧两图为MAML模型,预训练模型和测试正弦曲线在测试集上的对比,可以看到MAML模型预测的幅度和相位明显优于预训练模型。图2最左则的图显示MAML只需要经过少量梯度迭代就可以学习到一个新的正弦曲线;MAML也可以在样本点只覆盖了正弦曲线周期的一半的情况下依然可以很好的预测出另外一半曲线,说明MAML训练出来的模型是真正在学习正弦曲线,而不是简单的拟合训练样本点。图2右侧两图为预训练模型在不同参数下的测试结果,都表明其不能很好的预测到正弦曲线的幅度和相位。

 

(2)图像分类任务

图像分类任务作者是基于小样本学习(Few-shot Learning),作者采用了四层卷积层加一层全连接层的神经网络来预测图片类别,损失函数采用的为常用的交叉熵(Cross Entropy)。



作者主要在Omniglot和MiniImageNet 小样本学习数据集进行了实验,Omniglot是小样本学习中常用的字符识别数据集,包含了五十个不同字母表的1623个字符。上表表明,相比之前的SOAT的小样本学习的方法,MAML在Omniglot数据集上的分类也同样十分的出色。在MiniImageNet上,MAML训练模型明显的优于预训练的Fine-tuning Baseline,同时测试结果也比其他元学习方法高几个点。
 
(3)强化学习任务
强化学习任务做了2D Navigation和Locomotion两个实验,作者采用两层100节点的隐含层的神经网络来训练策略(Policy)。作者除了使用训练MAML模型外,还使用同样的数据训练三个基础模型进行对比:一是在所有的训练任务上预训练出一种策略(Policy),然后在新任务下进行微调;二是直接随机初始化网络参数,然后在新任务上进行微调。三是把目标位置,目标速度或者目标方向输入给网络。
2D Navigation是在一个二维平面上有一个目标位置,需要给出速度策略使Agent到达目标位置,不同的目标位置代表了不同的任务。其奖励函数为Agent到目标位置的负平方距离。



如上图4所示,当面临一个新目标位置的时候,MAML训练的模型在经过几次梯度更新之后的奖励返回值会明显的大于预训练模型和随机初始化模型,说明MAML可以加快策略梯度(Policy Gradient)强化学习的训练时间;相比较预训练模型,MAML的模型到达的位置距离目标位置更加近。
 
四、总结与思考 
  1. 论文提出了一个简单的模型无关的元学习方法用于训练模型的参数,使得模型经过少量梯度更新次数便可以快速学习新的任务。
  2. 在回归和图像分类的小样本学习任务中表明MAML效果很好,同时可以加速策略梯度强化学习的训练过程。
  3. MAML的缺点:当网络较大时,MAML计算量将会非常的大;MAML的训练任务要求具有一定的相关性;有时候MAML的训练过程不是很稳定。
 
参考文献 

[1] Ravi, Sachin and Larochelle, Hugo. Optimization as a model for few-shot learning. In International Conference on Learning Representations (ICLR), 2017.

[2] Koch, Gregory. Siamese neural networks for one-shot image recognition. ICML Deep Learning Workshop,2015.

[3] Vinyals, Oriol, Blundell,Charles, Lillicrap, Tim, Wierstra, Daan, et al. Matching networks for one shot learning. In Neural Information Processing Systems (NIPS), 2016.

[4] Duan, Yan, Chen, Xi,Houthooft, Rein, Schulman, John, and Abbeel, Pieter. Benchmarking deep reinforcement learning for continuous control. In International Conference on Machine Learning (ICML), 2016a.

[5] Williams, Ronald J. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8(3-4):229–256, 1992.

[6] Schulman, John, Levine,Sergey, Abbeel, Pieter, Jordan, Michael I, and Moritz, Philipp. Trust region policy optimization. In International Conference on Machine Learning (ICML),2015.



原文作者:Chelsea Finn ,Pieter Abbeel , Sergey Levine


撰稿:方传明

编排:高 学

审校:殷 飞

 发布:金连文 


免责声明:1)本文仅代表撰稿者观点,个人理解及总结不一定准确及全面,论文完整思想及论点应以原论文为准。(2)本文观点不代表本公众号立场。 


往期精彩内容回顾



征稿启事:本公众号将不定期介绍文档图像分析与识别及相关领域的论文、数据集、代码等成果,欢迎自荐或推荐相关领域最新论文/代码/数据集等成果给本公众号审阅编排后发布。




(扫描识别如上二维码加关注)



您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存